{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 The Google Research Authors.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "This is a demo of the SOFT top-k operator (https://arxiv.org/pdf/2002.06504.pdf).\n",
    "We demonstrate the usage of the provided `TopK_custom` module in the forward and the backward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "# Set up\n",
    "'''\n",
    "import torch\n",
    "from torch.nn.parameter import Parameter\n",
    "import numpy as np\n",
    "import soft_ot as soft\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "# Use large tick labels on both axes for every figure in this notebook.\n",
    "for tick_axis in ('xtick', 'ytick'):\n",
    "    matplotlib.rc(tick_axis, labelsize=20)\n",
    "\n",
    "torch.manual_seed(1)  # seed PyTorch's random number generator\n",
    "num_iter = 100\n",
    "k = 3\n",
    "# A larger epsilon gives a smoother relaxation and needs fewer iterations (num_iter).\n",
    "epsilon = 5e-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the SOFT top-k operator from the settings chosen in the set-up cell.\n",
    "soft_topk = soft.TopK_custom(\n",
    "    k,\n",
    "    epsilon=epsilon,\n",
    "    max_iter=num_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======scores======\n",
      "[5, 2, 3, 4, 1, 6]\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "# Input the scores\n",
    "'''\n",
    "scores = [5, 2, 3, 4, 1, 6]  # input the scores here\n",
    "# Wrap the scores as a 1 x len(scores) float32 Parameter so autograd tracks them.\n",
    "scores_tensor = Parameter(torch.tensor([scores], dtype=torch.float32))\n",
    "print('======scores======')\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======topk scores======\n",
      "[0.15886909 0.84113085 0.63542354 0.36457655 0.9414631  0.05853688]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAADnCAYAAAAgo4yYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAUkUlEQVR4nO3de7BdZXnH8e/vRMQYabiEyHQooiEhUFQgEbCRS2SMMbSUIoxMJQjehgITYMBpKyrGSrGt0gSUVlAIxKJT2gJ1oBjFyFWGJoVRJBBAw0XkEq4x4c7TP961x+3m7LPXTvY+73t2fp+ZPStnrbXf/ZzMOc9+zrPed21FBGZmNvqGcgdgZra5cgI2M8vECdjMLBMnYDOzTJyAzcwycQI2syJJii4e1+aOd2M4AZvZIJiUO4CN8YbcAZiZtSOp1nljdT2DE7CZFWtoqN4f6a+++mqfI+kPJ2AzK5IkJ2Azs1zqtiDGKidgMyuWE7CZWSZOwGZmmTgBm5llIIlx48blDqOvnIDNrFiugM3MMnECNjPLQJITsJlZLk7AZmaZ+CKcmVkGbkGYmWXkBGxmlokTsJlZJk7AZmaZOAGbmWXgpchmZhm5AjYzy8QJ2MwsA88DNjPLyAnYzCwTX4QzM8vALQgzs4ycgM3MMnECNjPLZGhoKHcIfTXY352ZjVmNHnCdR5fj7ijpIkmPSHpR0hpJiyRt0+U475N0VfX8FyQ9KOkaSXPrjuEK2MyK1etZEJKmALcAk4GrgLuBfYCTgbmSZkXEkzXG+SvgfGA9cAXwMLAjcDjwIUmfi4izOo4TERv7vZiZ9c348eNjl112qXXunXfeuTIiZnY6T9IPgDnAgog4r2n/OcCpwDcj4vgOY2wBPAFsCewZEfc0HdsNuB14DdgmIl4caSy3IMysSJIYGhqq9ag53hRS8l0DfKPl8Jmkana+pAkdhtoWmAisbk6+ABGxClgNjAfe0ikmJ2AzK1aPe8Czq+2yiHit+UBErANuBt4M7NdhnMdJFfA0SVNb4p0GTAXuqNPKcAI2s2J1kYAnSVrR9Pj0MMPtWm1Xt3m5e6vttJFiitS3PZGUP1dKukTS2ZIuBVYCvwCOrPP9+SKcmRWp0YKoaW2NHvDEavtsm+ON/Vt3erGIuFzSI8B3gWOaDj0GXAz8stMY4ArYzAo2bty4Wo/RJulo4EfAjcBupNbFbsB1wNeB79UZxxWwmRWrxyvhGhXuxDbHG/ufGWmQqs97EfAzYH5TP/luSfNJrY4jJR0UET8ZaSxXwGZWpF7PggAaMxba9XgbF9Ta9Ygb5gBbANcPczHvNeCG6ssZnQJyBWxmxepxBby82s6RNNScPCVtBcwCNgC3dhhny2q7fZvjjf0vdQrIFbCZFauX09Ai4n5gGbAzaRZDs4XABGBpRKxvev3pkqa3nHtjtT1C0rta4t0TOAII4MedYnIFbGZF6tOnIp9AWop8rqSDgVXAvqQ5wquBM1rOX9UIp7EjIm6TdDFwHPC/kq4AHiAl9sOANwKLIuIXnYJxAjazYvX6bmgRcb+kmcCXgLnAPOA3wGJgYUQ8XXOoT5B6vccCHwS2Ap4DbgIujAjPgjCzsa0f9wOOiIdI1Wudc4cNoFqMsaR6bDQnYDMrkj+SyMwso0G/IbsTsJkVyxWwmVkGfZoFURQnYDMrlitgM7NM3AM2M8vAsyDMzDJyBWxmlokrYDOzDDwLwswsI1fAZmaZOAGbmWXgWRBmZhk5AZuZZeJpaGZmmbgCNjPLoPGpyIPMCdjMiuUK2MwsEydgM7NMnIDNzDJwD9jMLCNXwGZmmTgBm5ll4gRsZpaB7wVhZpaRE7CZWSaeBWFmlokrYDOzDNwDNjPLyAnYzCwTJ2Azs0x8Ec7MLAP3gM3MMhr0BDzY9b2ZjWmNKrjTo8sxd5R0kaRHJL0oaY2kRZK22Yj49pZ0maSHq7Eek3S9
pGPqPN8VsJkVq9cVsKQpwC3AZOAq4G5gH+BkYK6kWRHxZM2xTgIWA08DVwO/BrYF9gDmAZd2GsMJ2MyK1YcWxPmk5LsgIs5rep1zgFOBs4Dja8Q1BzgX+CFwRESsazm+RZ1g3IIwsyI1bshe51FzvCnAHGAN8I2Ww2cC64H5kibUGO6fgOeBv2xNvgAR8XKdmFwBm1mxejwNbXa1XRYRrzUfiIh1km4mJej9gOvaDSJpD+BdwJXAU5JmAzOAAO4AlreO344TsJkVq4sWxCRJK5q+viAiLmg5Z9dqu7rNGPeSEvA0RkjAwHuq7ePAT4ADWo7/XNLhEXFfp6CdgM2sSF3OcFgbETM7nDOx2j7b5nhj/9YdxplcbT9BuvB2CHAT8FbgC8DRwNWS3hkRL400kHvAZlasfkxD64FG3hwHHBUR10TEcxFxL3AMsIJURX+400BdVcDbbbdd7LTTTt0Gm9Vdd92VO4SuTJw4sfNJhdl2221zh9C1e+65J3cIXZkxY0buELq2cuXKtRGx/aaM0eMecKPCbfdL1tj/TIdxGscfjYifNh+IiJB0FTCTNL3tuyMN1FUC3mmnnbj++uu7eUp2e+65Z+4QujJv3rzcIXTtqKOOyh1C1/bff//cIXRlxYoVnU8qjKQHejBGL0JpaLzrTmtzfGq1bdcjbh2nXaJ+utqO7xSQe8BmVqQ+tBeWV9s5koaaZypI2gqYBWwAbu0wzq2kKWs7S5oQEetbju9RbX/VKSD3gM2sWL3sAUfE/cAyYGfgxJbDC4EJwNLmhCppuqTpLeNsAL4NvAn4spoCkPRO4FjgFeA/OsXkCtjMitWHC2wnkJYinyvpYGAVsC9pjvBq4IyW81c1QmnZ/3nS9LNTgPdWc4jfChxOSsynVAl/RK6AzaxYvZ4FUSXFmcASUuI9DZhCuqfDfnXvAxERzwH7A39Puv/DScCfkqajfTAiFtcZxxWwmRVJEuPGjev5uBHxEHBczXPbZveI+C2pYm6tmmtzAjazYg36/YCdgM2sWE7AZmaZOAGbmWXgz4QzM8vICdjMLBN/LL2ZWSaugM3MMmh8JNEgcwI2s2K5AjYzy8QJ2MwsA7cgzMwycgVsZpaJE7CZWSZOwGZmmTgBm5ll4HtBmJll1I8bspfECdjMiuUK2MwsA7cgzMwy8kIMM7NMXAGbmWXiBGxmloHvBWFmlpETsJlZBp4FYWaWkROwmVkmTsBmZhlI8lJkM7NcXAGbmWXiBGxmlokTsJlZBl6IYWaWkStgM7NMBr0CHuzvzszGrEYLos6jy3F3lHSRpEckvShpjaRFkrbZhFgPkPSqpJD05brPcwVsZsXqdQtC0hTgFmAycBVwN7APcDIwV9KsiHiyyzG3Ai4BNgBv6ea5roDNrFiN+0F0enThfFLyXRARh0XE30TE+4F/BnYFztqIMBcDE4Gzu32iIqL+ydITwAPdvkgNk4C1fRi3n8ZazGMtXnDMo6Gf8b4tIrbf2CdPnTo1Fi9eXOvcQw45ZGVEzBzpnKr6vQ9YA0yJiNeajm0F/AYQMDki1td5XUl/DlwJzCd1FC4GzoqIz9V5flctiE35zxyJpBWd/vNKM9ZiHmvxgmMeDaXH2+MWxOxqu6w5+QJExDpJNwNzgP2A62rENhm4ELgyIr4j6dhuA3ILwsyK1LgXRJ1HTbtW29Vtjt9bbafVHO9CUg49vm4ArXwRzsyK1UUFPEnSiqavL4iIC1rOmVhtn20zRmP/1jXi+jhwKPCRiHisbpCtSknArf9RY8FYi3msxQuOeTQUHW8XCXjtaLVSJO0MLAIuj4h/35SxikjAw7xTFW+sxTzW4gXHPBpKjrcPS5EbFe7ENscb+5/pMM5FwPPACZsakHvAZlasHk9Du6fatuvxTq227XrEDXuTprI9US28CElBmgEBcEa178pOARVRAZuZDafHsyCWV9s5koaGmYY2i7SY4tYO41wKvHmY/VOBA4A7gJXA7Z0CcgI2s2L1
MgFHxP2SlpGmmp0InNd0eCEwAfhm8xxgSdOr597dNM6CNrEeS0rAV9edB5ylBVGtvY42j0dzxDQSSdtJ+qSkKyTdJ+l5Sc9KuknSJyQV18qR9A+SrpP0UBXvU5Jul3SmpO1yx1eXpKObfjY+mTueZpKOkHSepBslPVfF+J3ccdUh6eDq5/nR6n4Ij0j6gaR5uWNrqNt+6DJJnwA8Dpwr6UpJZ0v6MXAqqfVwRsv5q6pHX+SsgJ8lXUls9dvRDqSGI4F/Ia2UWQ48CLwVOBz4FvAhSUdGN8sK++9U4P+AH5J+4CaQJph/Efi0pP0i4qF84XUm6Y+Ar5N+JrpaYz9KPge8mxTfw8D0vOHUI+kfgc+QYv5v0kq47YEZwEHANdmCa9Hre0FUVfBM4EvAXGAe6fd6MbAwIp7u6Qt2kDMBPxMRX8z4+t1YTZrzd3VL3+izwG3Ah0nJ+D/zhDesP4iIF1p3SjoL+Czwt/TgKm6/KP3mXQw8CfwXcHreiIZ1KimJ3QccyO96jMWS9ClS8r0E+HREvNRyfIssgbXR6wQMUBUex9U8t3YAEbEEWNJNLMX96VyiiPhxRHx/mOWLjwL/Wn150KgHNoLhkm+lMW9xapvjpVgAvJ/0i1JrXf5oi4jlEXFvYX/5tCVpS9LNZh5kmOQLEBEvj3pgI+hDC6IoOSvgLSUdDexE+gX7GXBDRLyaMaaN0fiBfSVrFPX9WbX9WdYoRiBpN+ArwOKIuEHS+3PHNCA+QGo1LAJek3QIsAfwAnBbRPw0Z3DDGcvJtY6cCXgHYGnLvl9JOi4irs8RULckvQE4pvry2pyxtCPpdFL/dCIwE3gfKfl+JWdc7VT/p0tJVdpnM4czaN5TbV8gTZHao/mgpBuAIyLiidEObDhjvbqtI1cL4mLgYFISngC8E/gmsDPwP5LenSmubn2F9EN8TUT8IHcwbZwOnAmcQkq+1wJzSvklG8YXgL2AYyPi+dzBDJjJ1fYzQAD7A1sB7wKWkaZQXZ4ntOENegsiSwKOiIVVX/WxiNgQEXdGxPHAOcB40pX6oklaAJxGuqP+/MzhtBURO1QXEnYgXSh8B3C7pL3zRvZ6kvYlVb1fK/HP4QHQ+H1/BTg0Im6KiN9GxM+BvyBdUDxQ0nuzRdjCCXh0NS5oHZA1ig4knUSatnIXMDsinsocUkfVm90VpEno25FW8xSjaj1cSppx8vnM4Qyqxj0Obo+INc0HImID0Pgrbp/RDGokTsCjq/Fn8YSsUYxA0imkFTR3kpJvcQtHRhIRD5DeOP5Y0qTc8TR5C2mN/m7AC/r9NfZnVudcWO0bbv64dda4F0K7m8005sCOH4VYOlKfPpSzJKUtRd6v2v4yaxRtSPprUt/3DuADETGWPnqm2R9W25JmnLwIfLvNsb1JfeGbSEnE7YmNcx2p97u7Wu6FUGlclPvV6IbV3liubusY9QRcTTF6sPUzl5Tusfn16svilnNK+jxp9cxK0kWsYtsOkqYBj0XEsy37h4C/I12MuWW0V/2MpLrgNuxSY0lfJCXgSyLiW6MZ1yCJiAckfZ+0qOhk0gdRAiBpDvBBUnVc5IyeQZSjAv4IcFo15eUBYB0wBTgEeBNpGeRXM8TVlqSPkZLvq8CNwIJh3pnXVCthSjAPOFvSTaRq5knS0ukDSRfhHgU+lS+8wSDpMOCw6ssdqu17JS2p/r02IkpbwXci6c3snGoe8O3A20nfx6vAJ1vfuHNyBdx7y0mfzbQX6fZvE0jvujeR5n8uLXBl0dur7TjSdK7hXE+XyxD76EfALqRpZ3uRPmJlPekC11Lg3JIr+DFkT+BjLfveUT0gFRhFJeCIeFjSDNJ0v0NJF7yfA74PnB0Rt+WMr9WgJ+CuPpbezGy07L777nHZZZfVOnevvfbq+LH0JSrtIpyZGdCXjyQqjhOwmRVr0FsQTsBmViwnYDOzTAY9AQ92g8XMrGCugM2sSGP9Pg91OAGbWbE8C8LMLBNXwGZmmTgB
m5llsDn0gAe7wWJmVjBXwGZWrEG/CDfY352ZWcFcAZtZsQa9B+wEbGbFcgI2M8vAsyDMzKxvXAGbWbE8C8LMzPrCFbCZFWvQe8BOwGZWLCdgM7MMPAvCzGzASNpR0kWSHpH0oqQ1khZJ2qbm8ydI+qikyyTdLWm9pHWSVkg6TdIb68biCtjMitXrWRCSpgC3AJOBq4C7gX2Ak4G5kmZFxJMdhtkf+A7wFLAcuBLYBjgU+CpwuKSDI+KFTvE4AZtZsfrQgjiflHwXRMR5Ta9zDnAqcBZwfIcxHgWOBi6PiJeaxjgd+AnwJ8CJwNc6BeMWhJltFqrqdw6wBvhGy+EzgfXAfEkTRhonIu6IiH9rTr7V/nX8LukeVCcmJ2AzK1bjQlynR02zq+2yiHit+UCVPG8G3gzstwkhv1xtX6lzshOwmRWpbvLtIgHvWm1Xtzl+b7Wdtglhf7zaXlvnZPeAzaxYXSTXSZJWNH19QURc0HLOxGr7bJsxGvu3rvuizSSdBMwF7gAuqvMcJ2AzK1YXCXhtRMzsZywjkXQ4sIh0ge7DEfFyh6cATsBmVrAez4JoVLgT2xxv7H+mm0ElHQZ8D3gcmB0Rv6z7XPeAzWxzcU+1bdfjnVpt2/WIX0fSkcDlwGPAgRFxT4en/B5XwGZWrB5XwMur7RxJQ80zISRtBcwCNgC31ozto8AlwK/psvJtcAVsZkXq9SyIiLgfWAbsTFoo0WwhMAFYGhHrm2KYLmn6MLF9DLgUeBA4YGOSL7gCNrOC9WEl3AmkpcjnSjoYWAXsS5ojvBo4o+X8VY1QmmKaTZrlMESqqo8bJs5nImJRp2CcgM2sWL1OwBFxv6SZwJdIU8bmAb8BFgMLI+LpGsO8jd91Dz7e5pwHSLMiRuQEbGablYh4CDiu5rmveweIiCXAkl7E4gRsZsXy/YDNzKwvXAGbWZE2h0/EcAI2s2INegJ2C8LMLBNXwGZWLFfAZmbWF66AzaxYg14BOwGbWbEGPQG7BWFmlokrYDMr0uYwD9gVsJlZJk7AZmaZuAVhZsUa9BaEE7CZFWvQE7BbEGZmmbgCNrNiuQI2M7O+cAVsZsVyBWxmZn3hCtjMirQ5rIRzAjazYg16AnYLwswsE1fAZlYsV8BmZtYXroDNrFiDXgErInLHYGb2OpKuBSbVPH1tRMztZzz94ARsZpaJe8BmZpk4AZuZZeIEbGaWiROwmVkmTsBmZpn8P40gPBhzg52oAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "'''\n",
    "# Forward pass\n",
    "The goal of the forward pass is to identify the scores that belong to the top-k.\n",
    "The `soft_topk` object returns a smoothed indicator function: the entries are close to 1 for top-k scores and close to 0 for non-top-k scores.\n",
    "The smoothness is controlled by the hyper-parameter `epsilon`.\n",
    "'''\n",
    "A = soft_topk(scores_tensor)\n",
    "# detach() gives a gradient-free view for NumPy/plotting; A itself keeps its graph for the backward pass.\n",
    "indicator_vector = A.detach().numpy()\n",
    "# NOTE(review): in the saved output the largest weights fall on the smallest scores (1, 2, 3),\n",
    "# i.e. the operator appears to select the k smallest entries - confirm against soft_ot.\n",
    "print('======topk scores======')\n",
    "print(indicator_vector[0, :])\n",
    "\n",
    "# Show the soft indicator as a heat map, labelling each column with its score.\n",
    "plt.imshow(indicator_vector, cmap='Greys')\n",
    "plt.yticks([])\n",
    "plt.xticks(range(len(scores)), scores)\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======w.r.t score grad======\n",
      "[[ 0.10541585 -0.04306151 -0.07465184 -0.07465144  0.04347444  0.04347456]]\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "# Backward Pass\n",
    "\n",
    "The goal of training is to push the scores that should have been top-k to really be top-k.\n",
    "For example, in neural kNN, we want to push the scores with the same labels as the query sample to be top-k.\n",
    "In this demo, we mimic the loss function of neural kNN.\n",
    "`picked` holds the indices of the scores with the same label as the query sample. Our training goal is to push them to be top-k.\n",
    "'''\n",
    "\n",
    "picked = [1, 2, 3]\n",
    "# backward() accumulates into .grad, so clear any gradient left over from an earlier run of this cell.\n",
    "if scores_tensor.grad is not None:\n",
    "    scores_tensor.grad.zero_()\n",
    "# The loss is the sum of the soft indicator entries at the picked positions.\n",
    "loss = sum(A[0, item] for item in picked)\n",
    "loss.backward()\n",
    "# Keep a copy so later changes to scores_tensor.grad do not affect the plot below.\n",
    "A_grad = scores_tensor.grad.clone()\n",
    "print('======w.r.t score grad======')\n",
    "print(A_grad.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAE4CAYAAABlvCnWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAcTElEQVR4nO3de3hU9b3v8fePkJtBAyIWwWDAoykSIIFARQyhUox4QaHb7bUiHrzU67YaFAotUnS7C8etm/08lIMtsNG2igKNPfuA1hsgeoAEDhc9YC1CglgC2yCJEwjJ9/yxwkhI0ICT+U2Sz+t5eIb5zczKZ02ST9b81lozzswQEZHoa+c7gIhIW6UCFhHxRAUsIuKJClhExBMVsIiIJypgERFP2p/Mnc866yxLT09vpigiIq1TUVHRPjPrcvz4SRVweno669evj1wqEZE2wDm3s7FxTUGIiHiiAhYR8UQFLCLiyUnNAYucSHV1NaWlpVRVVfmOIuJNUlIS5557LvHx8U26vwpYIqK0tJTTTz+d9PR0nHO+44hEnZmxf/9+SktL6dmzZ5MeoykIiYiqqio6d+6s8pU2yzlH586dT+pVoApYIkblK23dyf4OqIBF2pCnnnrqlG5rTp9++imZmZnfer9ly5bx4YcfntSyT+Ux3yQ9PZ19+/ZFbHkqYJFGHDlyxHeEZhGLBdxUsVDAkaYCFj82vQz/mgnTOgaXm17+TourrKzkqquuon///mRmZvLSSy8BsG7dOi655BL69+/P4MGDOXjwIFVVVYwfP56+ffuSnZ3N22+/DcCCBQsYPXo0l112GSNGjKCyspI77riDwYMHk52dzZ/+9CcAtm7dyuDBg8nKyqJfv358/PHH9bIsXryYn/3sZwA899xz9OrVC4C//e1vDB06FICioiLy8vIYOHAg+fn57Nmzp94yampq6NmzJ2ZGeXk5cXFxrFy5EoBhw4bx8ccfs3btWoYMGUJ2djaXXHIJ27ZtC6/H2LFjueKKK7jggguYOHEiAI8//jihUIisrCxuueWWel+vsdueeeYZMjMzyczM5Nlnn230ee/QoQMPP/wwffr0YcSIEZSVlQEwfPjw8Fmz+/bt4+hbGJzouaupqeHOO++kT58+XH755YRCoXpfZ82aNRQWFlJQUEBWVhaffPIJGzdu5OKLL6Zfv36MGTOGL7744pQfM3z4cB566CGysrLIzMxk7dq1ja7vUaFQiFGjRjFv3rxvvN+3MrMm/xs4cKCJNObDDz9s+p3/70tmM75n9sszvv4343vB+Cl65ZVXbMKECeHr5eXldujQIevZs6etXbvWzMwOHDhg1dXVNmvWLBs/fryZmX300UeWlpZmoVDI5s+fb927d7f9+/ebmdmkSZNs0aJFZmb2xRdf2AUXXGAVFRV2//332wsvvGBmZocOHbKvvvqqXpY9e/ZYTk6OmZn9+Mc/tpycHCstLbUFCxbY448/bocPH7YhQ4bY3r17zczsj3/8YzjPsfLz823Lli322muvWU5Ojs2YMcOqqqosPT293vqYmb3xxhs2duxYMzObP3++9ezZ08rLyy0UClmPHj1s165dZmaWkpJywufw2NvWr19vmZmZVlFRYQcPHrSLLrrIiouLGzwGCD8XTzzxhN13331mZpaXl2fr1q0zM7OysjI777zzzMwafe527NhhcXFxtmHDBjMzu/7668PP+7HGjRtnixcvDl/v27evvfPOO2ZmNnXqVHvooYdO+TF5eXnhn593333X+vTp0+hzdN5559mOHTtsxIgRtnDhwkbv09jvArDeGulUHYYm0ffmdKiuv4VDdSgY7/ePp7TIvn378sgjj/DYY49x9dVXk5uby+bNmznnnHMYNGgQAGeccQYAq1ev5oEHHgDg+9//Pueddx7bt28HYOTIkZx55pkAvP766xQWFjJr1iwgONJj165dDBkyhCeffJLS
0lLGjh3LBRdcUC9L165dqaio4ODBg5SUlHDzzTezcuVKVq1axdixY9m2bRtbtmxh5MiRQLD1d8455zRYp9zcXFauXMmOHTuYNGkS8+bNIy8vL7w+Bw4cYNy4cXz88cc456iurg4/dsSIEaSmpgJw0UUXsXPnTtLS0pr8fK5evZoxY8aQkpICwNixY1m1ahXZ2dn17teuXTtuuOEGAG699VbGjh37jcs90XPXs2dPsrKyABg4cCCffvrpNy7nwIEDlJeXk5eXB8C4ceO4/vrrv9NjbrrpJiB4hfHll19SXl5Ox44dGyzn2muvZeLEiQ1eRZwKTUFI9B0oPbnxJrjwwgspLi6mb9++TJkyhenTp5/Sco4WDgSvDl999VU2btzIxo0b2bVrF7179+bmm2+msLCQ5ORkrrzySt56660Gy7nkkkuYP38+GRkZ5ObmsmrVKt5//32GDh2KmdGnT5/wcjdv3szrr7/eYBnDhg1j1apVrF27liuvvJLy8nLeeecdcnNzAZg6dSo//OEP2bJlC6+99lq9w58SExPD/4+Li4vanPbRowDat29PbW0tQL1cJ3rufOU91vFHMDjnyM/PJysriwkTJoTHhw4dyvLly7EIfKCxCliiL/Xckxtvgs8++4zTTjuNW2+9lYKCAoqLi8nIyGDPnj2sW7cOgIMHD3LkyBFyc3N58cUXAdi+fTu7du0iIyOjwTLz8/OZPXt2+Bdtw4YNQDCX26tXLx588EGuvfZaNm3a1OCxubm5zJo1i2HDhoXnmRMTE0lNTSUjI4OysjLef/99IDiLcOvWrQ2WMXjwYNasWUO7du1ISkoiKyuLuXPnMmzYMCDYouvevTsQzPs2RXx8fL0t5RPdlpuby7Jly/jqq6+orKxk6dKl4eI/Vm1tLa+88goAv//977n00kuB4GiBoqIigPDtTX3uTuT000/n4MGDAKSmptKpUydWrVoFwKJFi8Jbtqf6mKP7DVavXk1qaiqpqamsWLGCjRs38vzzz4fvN336dDp16sR9993X5OwnogKW6BvxC4hPrj8WnxyMn6LNmzeHd+488cQTTJkyhYSEBF566SUeeOAB+vfvz8iRI6mqquLee++ltraWvn37csMNN7BgwYJ6W2BHTZ06lerqavr160efPn2YOnUqAC+//DKZmZlkZWWxZcsWbrvttgaPzc3NpaSkhGHDhhEXF0daWlq4nBISEnjllVd47LHH6N+/P1lZWaxZs6bBMhITE0lLS+Piiy8OL/PgwYP07dsXgIkTJzJp0iSys7ObvMV411130a9fv0ZfPh9724ABA7j99tsZPHgwP/jBD5gwYUKD6QcIXjGsXbuWzMxM3nrrLX7xi+B7+OijjzJnzhyys7PrHbbVlOfuRG688UZmzpxJdnY2n3zyCQsXLqSgoIB+/fqxcePG8Nc+1cckJSWRnZ3NPffcw29/+9tvzPLcc88RCoXCOzhPlTuZzeicnBzT+wFLYz766CN69+7d9AdsejmY8z1QGmz5jvjFKc//ij8dOnSgoqLCd4zvbPjw4cyaNYucnJzvvKzGfhecc0Vm1mDh2gknfvT7RxWutHkqYBE5Za1h6xfgnXfe8fJ1NQcsIuKJClgiJhKH5Yi0ZCf7O6AClohISkpi//79KmFps6zu/YCTkpKa/BjNAUtEnHvuuZSWlobfC0CkLTr6iRhNpQKWiIiPj2/ypwCI+PTGG29w6aWXkpyc/O13bmaaghCRNuO9997j8ssvZ+7cub6jACpgEWlDJk2aBMCMGTM4fPiw5zQqYBFpI4qLi8PvUVxVVcXChQs9J1IBi0gb8cILL1BTUwMEn3jyu9/9znMiFbCItBFTp04Nn/E2e/ZsFi9e7DcQKmARaSM6derEkCFDABgwYMBJHS7WXFTAIiKeqIBFRDxRAYuIeKICFhHxRAUsIuKJClhE2oxlG3YDcM3s1Qx9+q3wdV9UwCLSJizbsJtJSzYDYMDu8hCTlmz2WsIqYBFp
E2au2EaoOjgTzsUFbwQZqq5h5opt3jLp7ShFpE34rDwEQLe7nye+Y9cG4z5oC1hE2oRuHYP3/z22fI8d90EFLCJtQkF+BsnxcfXGkuPjKMjP8JRIUxAi0kZcl90dCOaCPysP0a1jMgX5GeFxH2K2gFeuXMngwYNP6gPuRES+yXXZ3b0W7vFicgrigw8+IC8vj+eff953FBGRZhOTBTx58mQAnnjiCaqrqz2nERFpHjFXwJs2beKDDz4Ago8NefHFFz0nEhFpHjFXwD//+c85dOgQABUVFUyZMiX8MSIiIq1JTBXw9u3b+ctf/kJtbW147MCBAyxZssRjKhGR5hFTBTxt2rQGc74VFRVMnjwZM/OUSkSkecRMAe/atYulS5c2Ot3w+eefs3z5cg+pRESaT8wU8IwZM04411tRUcHjjz+urWARaVViooD37t3L/Pnzv/GQs61bt7Jq1aoophIRaV4xcSZchw4dGDNmDKHQ1+9K9Oc//5mcnBy6dv36jTO6devmI56ISLNwJ/OyPicnx9avX9+Mcb7mnGP9+vUMHDgwKl9PRKS5OOeKzCzn+PGYmIIQEWmLVMAiIp6ogEVEPFEBi4h4ogIWEfFEBSwi4okKWETEExWwiIgnKmAREU9UwCIinqiARUQ8UQGLiHiiAhYR8UQFLCLiiQpYRMQTFbCIiCcqYBERT1TAIiKeqIBFRDxRAYuIeBKTBbxsw24Arpm9mqFPvxW+LiLSmsRcAS/bsJtJSzYDYMDu8hCTlmxWCYtIqxNzBTxzxTZC1TUAtEtMASBUXcPMFdt8xhIRibj2vgMc77PyEABpP3uVdvGJDcZFRFqLmNsC7tYxGaBe+R47LiLSWsRcARfkZ5AcH1dvLDk+joL8DE+JRESaR8xNQVyX3R0I5oI/Kw/RrWMyBfkZ4XERkdYi5goYghJW4YpIaxdzUxAiIm2FClhExBMVsIiIJypgERFPVMAiIp6ogEVEPFEBi4h4ogIWEfFEBSwi4okKWETEExWwiIgnKmAREU9UwCIinqiARUQ8UQGLiHiiAhYR8UQFLCLiiQpYRMQTFbCIiCcqYBERT1TAIiKeqIBFRDxRAYuIeKICFhHxRAUsIuKJClhExBMVsIiIJypgERFPVMAiIp6ogEVEPGnvO4BIS7KucC5pxTM528rY67pQMqCAQaPv9h1LWigVsEgTrSucS2bRFJLdYXDQlTJSi6awDlTCcko0BSHSRGnFM0l2hzEzzAyAZHeYtOKZnpNJS6UCFmmis60MgEffqOLRN6qOGd/nK5K0cCpgkSba67oA8Mz71TzzfvUx42f5iiQtnApYpIlKBhQQsoR6YyFLoGRAgadE0tJpJ5xIEw0afTfrALgHgM/pQslAHQUhp84d3ZnQFDk5ObZ+/fpmjCMS+5xzAJzM7460bc65IjPLOX5cUxAiIp6ogEVEPFEBi4h4ogIWEfFEBSwi4okKWETEExWwiIgnKmAREU9UwCIinqiARUQ8UQGLiHiiAhYR8UQFLNIEq1evZtSoUYwaNSo8dvT6e++95zGZtGR6O0qRJtixYwdvvvkm1dVfvxH78uXLiY+P5+abb2bo0KEe00lLpS1gkSa46aab6NSpU4PxM888k5tuuslDImkNVMAiTdC+fXumTZtGSkpKeCwlJYVp06bRvr1eSMqpUQGLNNEdd9xBfHx8+HpCQgLjx4/3mEhaOhWwSBMlJiYyefJkTjvtNE477TQmT55MYmKi71jSgukjiUROQmVlJV27dsU5x549e+pNSYicyIk+kkiTVyInISUlhd/85jfh/4t8FypgkZN0yy23+I4grYTmgEVEPFEBi4h4ogIWEfFEBSwi4okKWETEExWwiIgnKmAREU9UwCIinqiARUQ8UQGLiHiiU5ElotYVziWteCZnWxl7XRdKBhQwaPTdvmOJxCQVsETMusK5ZBZNIdkdBgddKSO1aArrQCUs0ghNQUjEpBXPJNkdxsyoOBy8zWmyO0xa
8UzPyURikwpYIuZsKwPg3/7PYb436+Ax4/t8RRKJaSpgiZi9rgsAW8tq+ar62PGzPCUSiW0qYImYkgEFhCyh3ljIEigZUOApkUhsUwFLxAwafTdbBs7gK5IA+JwubBk4QzvgRE5AR0FIRA0afTen/bkIiufRddpf6eo7kEgM0xawiIgnKmAREU9UwCIinqiARUQ8UQGLiHiiAhYR8UQFLCLiiQpYRMQTFbCIiCcqYBERT1TAIiKeqIBFRDxRAUtEbN++nZ49e9KjRw/mzZsHQI8ePejRowdTpkzxnE4kNund0CQiUlJS2L17N9XVX78Te0lJCQkJCSQkJHzDI0XaLm0BS0R0796dG2+8kfbt6/9Nj4+P56GHHvKUSiS2qYAlYqZNm1avgJOSkrj33ntJTU31mEokdqmAJWJ69erFqFGjaNcu+LFyzlFQoI8jEjkRFbBE1K9+9SsSExOJj4/n9ttvp0uXLr4jicQsFbBEVJ8+fbj00kupqalh8uTJvuOIxDQdBSERN2fOHDZv3sy5557rO4pITFMBS8Sdf/75nH/++b5jiMQ8TUGIiHiiAhYR8UQFLCLiiQpYRMQTFbCIiCcqYBERT1TAIiKeqIBFRDxRAYuIeKICFhHxRKciR9m6wrmkFc/kbCtjr+tCyYACBo2+23csEfFABRxF6wrnklk0hWR3GBx0pYzUoimsA5WwSBukKYgoSiueGZQvsH1/DQDJ7jBpxTN9xhIRT1TAUXS2lQGw6e81ZPx75THj+3xFEhGPVMBRtNcFnw5RXXP8+Fke0oiIbyrgKCoZUEDI6n9Ee8gSKBmgz00TaYtUwFE0aPTdbBk4g310AuBzurBl4AztgBNpo3QURJQNGn03Rd1zYF4OXaf9la6+A4mIN9oCFhHxRAUsIuKJClhExBMVsIiIJyrgKNu0aROFhYUA/OEPf6C0tNRzIhHxRQUcZffffz9PP/00AOPGjWP27NmeE4mILyrgKHvkkUdISAhOxoiLi+OnP/2p50Qi4osKOMquueYaOnfujHOOq6++mvT0dN+RRMQTFXCUtWvXjieffBKA6dOne04jIj6pgD248cYb2blzJ7179/YdRUQ8UgF7EBcXR1pamu8YIuKZClhExBMVsIiIJypgERFPVMAiIp6ogEVEPFEBi4h4ogIWEfFEBSwi4okKWETq2bNnD5WVlb5jtAkqYBEJq6ys5MILL2TixIm+o7QJKmARCZszZw5Hjhxh/vz57Nu3z3ecVk8FLCIAHDp0iKeeeoqqqipqa2v59a9/7TtSq6cCFhEACgsL+fLLL2nXrh1mxpw5c6ipqfEdq1VTAYsIAFdccQXz58+ntraW9PR0li5dSlxcnO9YrZoKWEQAOP300/nJT34CQF5eHj/60Y88J2r9VMAiIp6ogEVEPFEBi4h4ogIWEfFEBSwi4okKWETClm3YDcAf1u5i6NNvha9L81ABiwgQlO+kJZvD13eXh5i0ZLNKuBmpgEUEgJkrthGqDs58c/FJAISqa5i5YpvPWK1ae98BRCQ2fFYeAqDrbc8Q3zmtwbhEnraARQSAbh2TAUg850LaJSQ3GJfIUwGLCAAF+Rkkx9d/74fk+DgK8jM8JWr9NAUhIgBcl90dCOaCPysP0a1jMgX5GeFxiTwVsIiEXZfdXYUbRZqCEBHxRAUsIuKJClhExBMVsIiIJypgERFPVMAiIp6ogEVEPFEBi4h4ogIWEfFEBSwi4okKWETEExWwiIgnKmAREU9UwCIinqiARUQ8UQGLiHiiAhYR8UQFLCLiiQpYRMQTFbCIiCcqYBERT1TAIiKeqIBFRDxRAYuIeKICFhHxRAUsIuKJClhExBMVsIiIJypgERFPVMAiIp6ogEVEPFEBi4h4ogIWEfFEBSwi4okzs6bf2bkyYGfzxWngLGBfFL9eNLXmdQOtX0un9Yus88ysy/GDJ1XA0eacW29mOb5zNIfWvG6g9WvptH7RoSkIERFPVMAiIp7EegH/
T98BmlFrXjfQ+rV0Wr8oiOk5YBGR1izWt4BFRFotFbCIiCcxVcDOuU+dc3aCf5/7zvddOOc6O+cmOOeWOuf+6pwLOecOOOdWO+f+u3Mupr4XJ8s59y/OuTedcyV16/ZfzrkNzrlfOuc6+87XHJxztx7z8znBd57vwjn3D8652c65Vc65L+vW6QXfuSLNOTei7nfwc+fcIefcZ865Fc65K33kae/ji36LA8CzjYxXRDtIhF0PzAH2AG8Du4DvAWOB54FRzrnrreVOyj8MFANvAHuBFOBiYBpwl3PuYjMr8RcvspxzacC/E/xcdvAcJxKmAP0J1qcU+L7fOJHnnPs1UECwfoUEJ2J0AQYCw4H/jHamWCzgcjOb5jtEM9gOjAb+l5nVHh10zk0G1gI/JijjV/3E+87OMLOq4wedc08Ck4FJwL1RT9UMnHMOmA/sB5YAj/pNFBEPExTTX4E8go2EVsM5dydB+S4E7jKzw8fdHu8jV4t+2duSmNlbZvbaseVbN/458Ju6q8OjHixCGivfOi/XXV4QrSxR8CBwGTAeqPScJSLM7G0z+7gFvwI7IedcIvAkwavOBuULYGbVUQ9GbG4BJzrnbgV6EPxwbwJWmlmN31jN6ug3/4jXFM3jmrrLTV5TRIhzrjfwNPCcma10zl3mO5N8q5EEUw3PArXOuauATKAKWGtm7/sKFosF3BVYdNzYDufceDN710eg5uScaw/cVnd1uc8skeCce5RgTjQVyAEuJSjfp33mioS679Uigi2pyZ7jSNMNqrusAjYQlG+Yc24l8A9mVhbtYLE2BTEfGEFQwilAX2AukA78b+dcf3/Rms3TBD8Q/2lmK3yHiYBHgV8C/0RQvsuBy338cDeDXwDZwO1mFvIdRprs7LrLAsCAXOB0oB/wOjAMWOwjWEwVsJk9UTdX+ncz+8rMtpjZPcAzQDLBHvVWwzn3IPAI8P+An3iOExFm1tXMHMEf0bFAL2CDc26A32TfjXPuBwRbvf/D50tWOSVHe+4IMNrMVptZhZltBsYQ7HzMc84N8RUs1h3dSTXMa4oIcs7dDzwHfAj80Mz+y3OkiKr7I7oUuBzoDPyH50inrG7q4T8IjmSZ6jmOnLzyussNZvbpsTeY2VfA0Veeg6MZClpOAR99+ZriNUWEOOf+CZgNbCEo3xZ9ksk3MbOdBH9k+jjnzvKd5xR1AC4EegNVx54gRDDdAjCvbqyxY9jFr211l+UnuP2LusvkKGSpJxZ3wjXm4rrLv3lNEQHOuccI5n03AiPNrDV/6sBR3eouW+qRLIeA357gtgEE88KrCX7RNT0Re94kmPu9yDnX7vhDQfl6p9yO6MaKoQKuO7xnl5lVHjeeTnDGEUCLPjXSOTcVmA4UEeyYahXTDs65C4G/m9mB48bbAb8i2Amyxsy+aOzxsa5uh1ujpxo756YRFPBCM3s+mrmkacxsp3PuNYIToR4C/vXobc65y4F8gq3jqB+FFDMFDNwAPFJ3SMhO4CBwPnAVkERwmuAsf/G+G+fcOILyrQFWAQ8GJ1TV86mZLYhytEi4Evhn59xqgq2I/QSnWecR7IT7HLjTXzz5Ns6564Dr6q52rbsc4pxbUPf/fWbWks/4u4/gD+UzdccBbwB6EqxzDTDh+A2IaIilAn4byCB4koYSzPeWE7y0WwQsauFn6fSsu4wjOESrMe8CC6KSJrL+Avw3gsPOsoGOBCfRbCf43v1ba9nab8WygHHHjfWq+wfBRlGLLWAzK3XODSQ4lHA0wQ79L4HXgH82s7U+cukN2UVEPGkpR0GIiLQ6KmAREU9UwCIinqiARUQ8UQGLiHiiAhYR8UQFLCLiiQpYRMQTFbCIiCcqYBERT/4/MXeDIxO0T2EAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "'''\n",
    "# Visualization of the Grad\n",
    "'''\n",
    "\n",
    "positions = range(len(scores))\n",
    "# Rescale the gradient to unit L2 norm before using it as arrow lengths.\n",
    "raw_grad = A_grad.numpy()[0, :]\n",
    "unit_grad = raw_grad / np.linalg.norm(raw_grad)\n",
    "\n",
    "plt.figure(figsize=(len(scores), 5))\n",
    "# All scores first, then the picked ones drawn on top with a legend entry.\n",
    "plt.scatter(positions, scores)\n",
    "plt.scatter(picked, [scores[idx] for idx in picked], label='scores we want to push to smallest top-k')\n",
    "# One vertical arrow per score, pointing along its normalized gradient component.\n",
    "for pos, (score, delta) in enumerate(zip(scores, unit_grad)):\n",
    "    plt.arrow(pos, score, 0, delta, head_width=abs(delta) / 4, fc='k')\n",
    "plt.xticks(positions, scores)\n",
    "plt.yticks([])\n",
    "plt.xlim([-0.5, len(scores) - 0.5])\n",
    "plt.ylim([min(scores) - 1, max(scores) + 1])\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 150,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Clear the gradient before re-running the forward-backward code.\n",
    "# Guarded so the cell also works when no backward pass has populated .grad yet.\n",
    "if scores_tensor.grad is not None:\n",
    "    scores_tensor.grad.zero_()\n",
    "scores_tensor.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
